###############################################################################
## AUTHOR: ALAN YAN
## DATE LAST UPDATED: 07/28/2020
## PURPOSE: HET EFFECTS OF UNION TREATMENT VIA RANDOM FORESTS
##          (OUTCOMES: WORK, POWER, RESPONSIBILITY, COMPLAINTS)
###############################################################################
# Start from an empty workspace: remove every object in the current environment.
# NOTE(review): this does not unload packages or reset options; running the
# script in a fresh R session gives a cleaner guarantee.
rm(list = ls())

#### NOTES ####

#(1) create a tidier function for random forest output

#### PACKAGES ####
library(pacman)
p_load(tidyverse,
       estimatr,
       broom,
       haven,
       hrbrthemes,
       cjoint,
       formula.tools, 
       DeclareDesign,
       emmeans,
       lmtest,
       cregg,
       grf)
source("02-src/forest-functions.R")

#### FUNCTIONS ####

# ggplot theme
#
# theme_minimal() in "Roboto Condensed" with plot, panel and legend backgrounds
# filled in #f5f5f2 and no border.
#   base_size: base font size forwarded to theme_minimal().
#   waffle:    when TRUE, axis text is additionally blanked out.
# Returns a ggplot2 theme object.
theme_shom_alt <- function(base_size = 11, waffle = FALSE) {
  canvas <- element_rect(fill = "#f5f5f2", color = NA)
  styled <- theme_minimal(base_size = base_size,
                          base_family = "Roboto Condensed") +
    theme(plot.background = canvas,
          panel.background = canvas,
          legend.background = canvas)
  if (waffle) {
    return(styled + theme(axis.text = element_blank()))
  }
  styled
}

#### LOAD DATA ####
# Read the cleaned conjoint file from disk.
df <- read_rds("01-yougov-conjoint/01-data/cleaned/yougov-conjoint-cleaned.rds")


#### CODE DATA FOR CAUSAL FORESTS ####

# Build 0/1 indicators in two passes (respondent traits, then conjoint attribute
# levels) and keep the result as df_grf.
# NOTE(review): the `==` tests yield NA for missing inputs while the `%in%`
# tests yield 0 -- confirm that asymmetry is intended.
df_grf <- df %>%
  mutate(
    is_college = ifelse(r_edu %in% c("4-year college", "Post-grad"), 1, 0),
    is_white = ifelse(r_race == "White", 1, 0),
    # r_party: 1-3 -> Democrat, 5-7 -> Republican, 4 -> independent
    is_democrat = ifelse(r_party %in% c(1, 2, 3), 1, 0),
    is_republican = ifelse(r_party %in% c(5, 6, 7), 1, 0),
    is_independent = ifelse(r_party == 4, 1, 0),
    is_female = ifelse(r_female == "Female", 1, 0),
    is_union = ifelse(r_union_member == "Union member", 1, 0),
    is_employed_ft = ifelse(r_employment == "Full-time", 1, 0),
    is_employed_pt = ifelse(r_employment == "Part-time", 1, 0),
    is_retired = ifelse(r_employment == "Retired", 1, 0),
    # the six employment statuses listed here are pooled into one indicator
    is_unemployed = ifelse(
      r_employment %in% c("Homemaker", "Other",
                          "Unemployed", "Temporarily laid off",
                          "Student", "Permanently disabled"),
      1, 0
    )
  ) %>%
  mutate(
    cje_cg_worker_shareholders =
      ifelse(cje_corporate_gov == "Workers are shareholders", 1, 0),
    cje_cg_worker_corporate_board =
      ifelse(cje_corporate_gov == "Workers sit on the corporate board", 1, 0),
    cje_cg_worker_elect_managers =
      ifelse(cje_corporate_gov == "Workers elect their managers", 1, 0),
    # binary treatment used by every forest below
    cje_union_unionshop = ifelse(cje_union == "Unionized", 1, 0)
  )

#### LIST OF COVARIATES ####

# Column names passed as `covars` to every forest estimated below.
covariates <- c(
  "is_college",
  "is_white",
  "is_democrat",
  "is_republican",
  "is_independent",
  "is_female",
  "is_union",
  "is_employed_ft",
  "is_employed_pt",
  "is_retired",
  "is_unemployed",
  "r_age"
)




##### generalized random forests (work unionized versus not unionized) ####

# Estimate the forest for the `work_binary` outcome, comparing the "Unionized"
# and "Not unionized" levels of `cje_union`.
# run_causal_forest_cjoint() is not defined in this file -- presumably it comes
# from the sourced 02-src/forest-functions.R.
effects_forest_union_work <- run_causal_forest_cjoint(
  df = df_grf,
  outcome = work_binary,
  treatment_cjoint = cje_union,
  treatment_binary = cje_union_unionshop,
  treatment_comparison = c("Unionized", "Not unionized"),
  covars = covariates
)

# Tidy the forest output and label it. The treatment label is spelled exactly
# as in the power / responsibility / complaints sections ("Not Unionized") so
# that the stacked table carries a single value for this treatment.
effects_forest_union_work_td <- tidy_causal_forest(effects_forest_union_work) %>%
  mutate(outcome = "Prefer to Work",
         treatment = "Unionized vs\nNot Unionized")

#plot individual-level predictions
# Rows without a party_char value are dropped first. Layers are listed in
# drawing order; scales, labels, coordinates and theme follow.
effects_forest_union_work_td %>%
  drop_na(party_char) %>%
  ggplot(aes(x = id, y = tau.hat)) +
  # per-row ranges (lower.90/upper.90, then the fainter lower.95/upper.95)
  geom_linerange(aes(ymin = lower.90, ymax = upper.90),
                 color = "indianred", alpha = 0.7) +
  geom_linerange(aes(ymin = lower.95, ymax = upper.95),
                 color = "indianred", alpha = 0.3) +
  # estimated causal effect full sample, with a +/- 1.96 * cate_se band
  geom_hline(aes(yintercept = unique(cate))) +
  geom_ribbon(aes(ymin = unique(cate) - 1.96 * unique(cate_se),
                  ymax = unique(cate) + 1.96 * unique(cate_se)),
              alpha = 0.3) +
  # per-row point estimates, filled by party
  geom_point(aes(fill = party_char), size = 1.5, shape = 21) +
  # zero reference line
  geom_hline(yintercept = 0, linetype = "dashed", size = 1.5) +
  scale_fill_manual(values = c("dodgerblue", "white", "indianred")) +
  labs(x = "", y = "Predicted Effect") +
  coord_flip() +
  theme_shom_alt(base_size = 16) +
  theme(axis.text.y = element_blank())

##### generalized random forests (power union vs non union) ####

# Estimate the forest for the `power_binary` outcome, comparing the "Unionized"
# and "Not unionized" levels of `cje_union`.
effects_forest_union_power <- run_causal_forest_cjoint(
  df = df_grf,
  outcome = power_binary,
  treatment_cjoint = cje_union,
  treatment_binary = cje_union_unionshop,
  treatment_comparison = c("Unionized", "Not unionized"),
  covars = covariates
)

# Tidy the forest output and attach the outcome / treatment labels.
effects_forest_union_power_td <- mutate(
  tidy_causal_forest(effects_forest_union_power),
  outcome = "More Power",
  treatment = "Unionized vs\nNot Unionized"
)


#plot individual-level predictions
# Rows without a party_char value are dropped first. Layers are listed in
# drawing order; scales, labels, coordinates and theme follow.
effects_forest_union_power_td %>%
  drop_na(party_char) %>%
  ggplot(aes(x = id, y = tau.hat)) +
  # per-row ranges (lower.90/upper.90, then the fainter lower.95/upper.95)
  geom_linerange(aes(ymin = lower.90, ymax = upper.90),
                 color = "indianred", alpha = 0.7) +
  geom_linerange(aes(ymin = lower.95, ymax = upper.95),
                 color = "indianred", alpha = 0.3) +
  # estimated causal effect full sample, with a +/- 1.96 * cate_se band
  geom_hline(aes(yintercept = unique(cate))) +
  geom_ribbon(aes(ymin = unique(cate) - 1.96 * unique(cate_se),
                  ymax = unique(cate) + 1.96 * unique(cate_se)),
              alpha = 0.3) +
  # per-row point estimates, filled by party
  geom_point(aes(fill = party_char), size = 1.5, shape = 21) +
  # zero reference line
  geom_hline(yintercept = 0, linetype = "dashed", size = 1.5) +
  scale_fill_manual(values = c("dodgerblue", "white", "indianred")) +
  labs(x = "", y = "Predicted Effect") +
  coord_flip() +
  theme_shom_alt(base_size = 16) +
  theme(axis.text.y = element_blank())

##### generalized random forests (responsibility unions versus no unions) ####

# Estimate the forest for the `responsibility_binary` outcome, comparing the
# "Unionized" and "Not unionized" levels of `cje_union`.
effects_forest_union_resp <- run_causal_forest_cjoint(
  df = df_grf,
  outcome = responsibility_binary,
  treatment_cjoint = cje_union,
  treatment_binary = cje_union_unionshop,
  treatment_comparison = c("Unionized", "Not unionized"),
  covars = covariates
)

# Tidy the forest output and attach the outcome / treatment labels.
effects_forest_union_resp_td <- mutate(
  tidy_causal_forest(effects_forest_union_resp),
  outcome = "More Responsibilities",
  treatment = "Unionized vs\nNot Unionized"
)

#plot individual-level predictions
# Rows without a party_char value are dropped first. Layers are listed in
# drawing order; scales, labels, coordinates and theme follow.
effects_forest_union_resp_td %>%
  drop_na(party_char) %>%
  ggplot(aes(x = id, y = tau.hat)) +
  # per-row ranges (lower.90/upper.90, then the fainter lower.95/upper.95)
  geom_linerange(aes(ymin = lower.90, ymax = upper.90),
                 color = "indianred", alpha = 0.7) +
  geom_linerange(aes(ymin = lower.95, ymax = upper.95),
                 color = "indianred", alpha = 0.3) +
  # estimated causal effect full sample, with a +/- 1.96 * cate_se band
  geom_hline(aes(yintercept = unique(cate))) +
  geom_ribbon(aes(ymin = unique(cate) - 1.96 * unique(cate_se),
                  ymax = unique(cate) + 1.96 * unique(cate_se)),
              alpha = 0.3) +
  # per-row point estimates, filled by party
  geom_point(aes(fill = party_char), size = 1.5, shape = 21) +
  # zero reference line
  geom_hline(yintercept = 0, linetype = "dashed", size = 1.5) +
  scale_fill_manual(values = c("dodgerblue", "white", "indianred")) +
  labs(x = "", y = "Predicted Effect") +
  coord_flip() +
  theme_shom_alt(base_size = 16) +
  theme(axis.text.y = element_blank())

##### generalized random forests (complaints unionized versus not unionized) ####

# Estimate the forest for the `complaints_binary` outcome, comparing the
# "Unionized" and "Not unionized" levels of `cje_union`.
effects_forest_union_complaints <- run_causal_forest_cjoint(
  df = df_grf,
  outcome = complaints_binary,
  treatment_cjoint = cje_union,
  treatment_binary = cje_union_unionshop,
  treatment_comparison = c("Unionized", "Not unionized"),
  covars = covariates
)

# Tidy the forest output and attach the outcome / treatment labels.
effects_forest_union_complaints_td <- mutate(
  tidy_causal_forest(effects_forest_union_complaints),
  outcome = "Better Handle Complaints",
  treatment = "Unionized vs\nNot Unionized"
)


#plot individual-level predictions
# Rows without a party_char value are dropped first. Layers are listed in
# drawing order; scales, labels, coordinates and theme follow.
effects_forest_union_complaints_td %>%
  drop_na(party_char) %>%
  ggplot(aes(x = id, y = tau.hat)) +
  # per-row ranges (lower.90/upper.90, then the fainter lower.95/upper.95)
  geom_linerange(aes(ymin = lower.90, ymax = upper.90),
                 color = "indianred", alpha = 0.7) +
  geom_linerange(aes(ymin = lower.95, ymax = upper.95),
                 color = "indianred", alpha = 0.3) +
  # estimated causal effect full sample, with a +/- 1.96 * cate_se band
  geom_hline(aes(yintercept = unique(cate))) +
  geom_ribbon(aes(ymin = unique(cate) - 1.96 * unique(cate_se),
                  ymax = unique(cate) + 1.96 * unique(cate_se)),
              alpha = 0.3) +
  # per-row point estimates, filled by party
  geom_point(aes(fill = party_char), size = 1.5, shape = 21) +
  # zero reference line
  geom_hline(yintercept = 0, linetype = "dashed", size = 1.5) +
  scale_fill_manual(values = c("dodgerblue", "white", "indianred")) +
  labs(x = "", y = "Predicted Effect") +
  coord_flip() +
  theme_shom_alt(base_size = 16) +
  theme(axis.text.y = element_blank())


#### PLOT BY PARTY ID (ALL) ####

#bind results
# Stack the four tidied forests, then rewrite `outcome` as the facet label
# "Outcome:\n <label>" and fix the facet order. The ordering is applied once,
# on the pasted labels: paste() returns character, so any factor ordering set
# before it would be thrown away.
plot_tbl <- bind_rows(effects_forest_union_work_td,
                      effects_forest_union_power_td,
                      effects_forest_union_resp_td,
                      effects_forest_union_complaints_td) %>%
  mutate(outcome = paste("Outcome:\n", outcome, sep = " "),
         outcome = fct_relevel(outcome,
                               "Outcome:\n Prefer to Work",
                               "Outcome:\n More Power",
                               "Outcome:\n More Responsibilities",
                               "Outcome:\n Better Handle Complaints"))

#### PLOT ACROSS ALL OUTCOMES (FIGURE S6) ####

#plot individual-level predictions
# Build the faceted figure as an object so it can be saved explicitly.
# Adding ggsave() to the `+` chain only worked through argument evaluation
# order and last_plot(), and raises "Can't add `ggsave()` to a ggplot object"
# on ggplot2 >= 3.3.4.
# NOTE(review): geom_rect() is mapped through aes(), so one rectangle is drawn
# per row; xmax uses max(plot_tbl$id) from the unfiltered table -- confirm both
# are intended.
fig_hte_union_pid <- plot_tbl %>%
  drop_na(party_char) %>%
  ggplot(aes(x = id,y = tau.hat)) +
  facet_wrap(~outcome) + 
  geom_linerange(aes(ymin = lower.90,ymax = upper.90,color = party_char),alpha = 0.2) + #CIs around point estimates
  geom_linerange(aes(ymin = lower.95,ymax = upper.95,color = party_char),alpha = 0.01) + 
  geom_hline(aes(yintercept = cate)) + #estimated causal effect full sample
  geom_rect(aes(ymin = cate - 1.96*cate_se,
                ymax = cate + 1.96*cate_se,
                xmin = 1, xmax = max(plot_tbl$id),
                x = NULL,y = NULL),
            alpha = 0.1) + 
  geom_point(size = 1.5,shape = 21,aes(fill = party_char)) +#estimated unit level causal effect
  scale_fill_manual(values = c("dodgerblue","white","indianred")) + 
  scale_color_manual(values = c("dodgerblue","white","indianred")) + 
  theme_shom_alt(base_size = 16) + 
  theme(axis.text.y = element_blank(),
        plot.caption = element_text(size = 10,hjust = 1),
        plot.caption.position = "plot",
        strip.background = element_rect(fill="grey"),
        strip.text = element_text(color = 'black'),
        legend.position = "bottom",
        strip.text.x = element_text(size = 11),
        strip.text.y = element_text(size = 11)) + 
  labs(x = "",
       y = "Predicted Effect of Unionized vs Non-Unionized Firm",
       fill = "Party ID",
       color = "Party ID",
       caption = "Notes: Estimates for the treatment effect for each trial where the comparison is between union vs non-union workplaces are ranked by magnitude and generated via causal forests. 
       The thin vertical line represent the estimated causal effect of each treatment with the vertical shaded regions representing the 95% confidence interval. 
       The horizontal shaded regions represent 95% confidence intervals of individual level treatment effect estimates.") + 
  geom_hline(yintercept = 0,linetype = "dashed",size = 1.5) + 
  coord_flip()

# Write the figure to disk, naming the plot explicitly rather than relying on
# last_plot().
ggsave("01-yougov-conjoint/03-src/output/figures/hte-union-pid.pdf",
       plot = fig_hte_union_pid,
       dpi = 320,width = 11,height = 14,device = cairo_pdf)

# Also render the figure on the active graphics device.
fig_hte_union_pid


